import os, csv, json
import random
import argparse
import numpy as np
from tqdm import tqdm
import time
from utils.SAM import SAM
# torch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.autograd.functional import hessian
from torch.autograd import grad
from utils.hvp import hessian_vector_product

# custom libs
from utils.datasets import load_dataset
from utils.networks import load_network, load_trained_network
from utils.optims import define_loss_function, define_optimizer


def train(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Run one plain (unregularised) training epoch.

    Args:
        args: run configuration; reads ``args.epoch``, ``args.network`` and
            ``args.cuda``.
        epoch: current epoch index, shown in the progress-bar label.
        net: model to optimise; switched to train mode.
        train_loader: iterable of ``(data, labels)`` batches that exposes
            ``.dataset`` (used for normalisation).
        taskloss: callable ``taskloss(output, labels)`` returning a scalar
            loss, assumed to be a batch mean (it is re-weighted by batch size).
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss)``: accuracy in percent and the
        sample-weighted mean loss over ``train_loader.dataset``.
    """
    net.train()
    # NOTE(review): the scheduler is stepped before any optimizer.step();
    # recent PyTorch versions warn about this ordering — confirm intent.
    if scheduler:
        scheduler.step()

    loss_sum, n_correct = 0., 0
    progress = tqdm(train_loader, desc='[{}/{}]'.format(epoch, args.epoch))
    for inputs, targets in progress:
        if args.network == 'LR':
            # the logistic-regression model consumes flattened 28x28 images
            inputs = inputs.view(-1, 28 * 28)
        if args.cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = inputs.detach(), targets.detach()

        optimizer.zero_grad()
        logits = net(inputs)
        batch_loss = taskloss(logits, targets)
        # undo the batch mean so the epoch figure is a per-sample average
        loss_sum += batch_loss.item() * inputs.size(0)
        batch_loss.backward()
        optimizer.step()

        predicted = logits.detach().max(1, keepdim=True)[1]
        n_correct += predicted.eq(targets.view_as(predicted)).sum().item()

    train_loss = loss_sum / len(train_loader.dataset)
    train_acc = 100. * n_correct / len(train_loader.dataset)
    return train_acc, train_loss

def percentile(list, q=50):
    """Return the ``q``-th percentile of ``list`` (the median by default).

    Args:
        list: sequence of numbers. The name shadows the builtin but is kept
            so keyword callers keep working.
        q: percentile in ``[0, 100]``. Defaults to 50, so existing callers
            still get the median.

    Returns:
        The value computed by ``numpy.percentile`` (linear interpolation).
    """
    return np.percentile(list, q)


def print_param(net):
    """Print ``net``'s ``conv1.weight`` parameter if it is trainable.

    NOTE(review): the message says "Gradient", but it is the parameter
    tensor itself (not ``param.grad``) that gets printed.
    """
    target = 'conv1.weight'
    for pname, tensor in net.named_parameters():
        if pname != target or not tensor.requires_grad:
            continue
        print(f"Gradient of parameter before {pname}: {tensor}")


def rademacher(shape, dtype=torch.float32, device="cuda"):
    """Sample a tensor of i.i.d. Rademacher entries (+1 or -1, each with prob. 1/2).

    Args:
        shape: output shape.
        dtype: output dtype (default ``torch.float32``).
        device: target device. Defaults to ``"cuda"`` to keep the original
            behaviour; pass ``"cpu"`` or ``param.device`` to use it without a GPU.

    Returns:
        Tensor of ``shape``/``dtype`` on ``device`` with values in {-1, +1}.
    """
    # Draw from the CPU generator (as before) so seeded runs reproduce the
    # same signs, then move to the requested device.
    signs = (torch.rand(shape) < 0.5).to(dtype) * 2 - 1
    return signs.to(device)

def estimate_trace(tloss, net, V):
    """Hutchinson estimate of the trace of the Hessian of ``tloss`` w.r.t. ``net``.

    Draws ``V`` Rademacher probes v (one tensor per trainable parameter) and
    obtains H·v from ``hessian_vector_product``. It then returns the mean,
    over probes, of the sum of v·(H·v) across all parameters.

    Args:
        tloss: scalar loss whose Hessian is probed. Whether the returned
            estimate is differentiable depends on ``hessian_vector_product``
            (utils.hvp) — TODO confirm.
        net: model; only parameters with ``requires_grad`` are probed.
        V: number of probe vectors (must be >= 1).

    Returns:
        0-d tensor holding the trace estimate.

    Raises:
        ValueError: if ``V < 1``.
    """
    if V < 1:
        raise ValueError(f"V must be a positive number of probes, got {V}")

    params = [p for p in net.parameters() if p.requires_grad]
    trace = 0
    for _ in range(V):
        # Probes match each parameter's dtype and device, so this also works on
        # CPU and with non-float32 parameters. Signs are still drawn from the
        # CPU generator, so seeded CUDA runs see the same draws as before.
        vec = [((torch.rand(p.shape) < 0.5).to(p.dtype) * 2 - 1).to(p.device)
               for p in params]
        Hvec = hessian_vector_product(tloss, params, vec)
        for v, Hv in zip(vec, Hvec):
            trace = trace + torch.dot(v.flatten(), Hv.flatten())
    return trace / V

def train_hessTrace(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch with a thresholded Hessian-trace penalty.

    For each batch the Hessian trace of the loss is estimated with one probe
    (``estimate_trace``). The running median of this epoch's estimates is
    computed via ``percentile``. When ``max(0, trace - median)`` exceeds the
    running threshold ``tau``, ``0.001 * trace`` is added to the loss;
    otherwise ``tau`` is reset to the current median.

    Args:
        args: reads ``args.epoch``, ``args.network`` and ``args.cuda``.
        epoch: current epoch index (progress-bar label).
        net, train_loader, taskloss, optimizer: model, ``(data, labels)``
            batches exposing ``.dataset``, batch-mean loss callable, optimizer.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss, batch_trace)``:
        - accuracy in percent;
        - sample-weighted mean task loss (without the penalty);
        - the sum of per-batch trace estimates divided by
          ``len(train_loader.dataset)``.
    """
    print("HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH")
    trace_coef = 0.001  # Adjust the factor as per your requirement
    train_corr = 0
    train_loss = 0.
    trace_values = []
    tau = 0.
    batch_trace = 0.

    net.train()
    # NOTE(review): scheduler stepped before optimizer.step(), as in the other
    # training loops of this file.
    if scheduler:
        scheduler.step()

    for data, labels in tqdm(train_loader, desc='[{}/{}]'.format(epoch, args.epoch)):
        if args.network == 'LR':
            data = data.view(-1, 28*28)
        if args.cuda:
            data, labels = data.cuda(), labels.cuda()

        optimizer.zero_grad()
        output = net(data)

        bsize = data.size(0)
        tloss = taskloss(output, labels)
        train_loss += tloss.item() * bsize

        trace = estimate_trace(tloss, net, V=1)
        trace_val = trace.item()
        trace_values.append(trace_val)
        batch_trace += trace_val

        median = percentile(trace_values)
        excess = max(0.0, trace_val - median)

        if excess > tau:
            # Add the trace as a tensor. The previous ``trace.item()`` was a
            # Python constant and contributed no gradient, so the penalty never
            # influenced training. Gradients reach the parameters only if
            # hessian_vector_product builds a differentiable graph — TODO confirm.
            tloss = tloss + trace_coef * trace
        else:
            tau = median

        # Clear again in case hessian_vector_product left values in .grad.
        optimizer.zero_grad()
        tloss.backward()
        optimizer.step()

        predict = output.detach().max(1, keepdim=True)[1]
        train_corr += predict.eq(labels.view_as(predict)).sum().item()

    n_samples = len(train_loader.dataset)
    train_loss /= n_samples
    train_acc = 100. * train_corr / n_samples
    batch_trace /= n_samples
    return train_acc, train_loss, batch_trace


def train_hessTrace_SAM(args, epoch, net, train_loader, taskloss, base_optimizer, scheduler):
    """Train one epoch with Sharpness-Aware Minimization (``utils.SAM``).

    Each batch does two passes:
    1. Backpropagate the loss at the current weights and call ``first_step``.
       This presumably perturbs the weights — confirm in utils.SAM.
    2. Run a fresh forward pass, backpropagate that loss and call
       ``second_step``. This presumably restores the weights and applies
       the base SGD update.

    NOTE(review): the SAM optimizer is rebuilt on every call with plain SGD
    (``lr=args.lr``, momentum 0.9). As a result ``base_optimizer`` is ignored,
    momentum state is reset each epoch, and ``scheduler`` cannot change this
    optimizer's learning rate. Fixing that requires the caller to own the SAM
    instance. The second forward pass also updates BatchNorm running
    statistics, if the model has any.

    Returns:
        ``(train_acc, train_loss)``, measured from the first (unperturbed)
        forward pass of each batch.
    """
    print("HHHHHHHHHfsdfsdfsfseHHHHHHHHHHHHHHHHH")
    train_corr = 0
    train_loss = 0.

    base_optim = torch.optim.SGD
    print(f"learning rate = {args.lr}")
    optimizer_sam = SAM(net.parameters(), base_optim, lr=args.lr, momentum=0.9)

    net.train()
    if scheduler:
        scheduler.step()

    for data, labels in tqdm(train_loader, desc='[{}/{}]'.format(epoch, args.epoch)):
        if args.network == 'LR':
            data = data.view(-1, 28*28)
        if args.cuda:
            data, labels = data.cuda(), labels.cuda()

        optimizer_sam.zero_grad()
        output = net(data)

        bsize = data.size(0)
        tloss = taskloss(output, labels)
        train_loss += tloss.item() * bsize

        # First pass: ordinary gradient. create_graph is not needed here, and
        # using it creates a parameter<->gradient reference cycle.
        tloss.backward()
        optimizer_sam.first_step(zero_grad=True)

        # Second pass: recompute the forward so the gradient is taken at the
        # weights left by first_step. Reusing ``output`` would give the gradient
        # at the original weights. If first_step perturbs parameters in place,
        # backpropagating through that stale graph can also raise.
        taskloss(net(data), labels).backward()
        optimizer_sam.second_step(zero_grad=True)

        predict = output.detach().max(1, keepdim=True)[1]
        train_corr += predict.eq(labels.view_as(predict)).sum().item()

    train_loss /= len(train_loader.dataset)
    train_acc = 100. * train_corr / len(train_loader.dataset)
    return train_acc, train_loss


def train_hessTrace_Normalized(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch with a min-max normalised Hessian-trace penalty.

    Each batch's trace estimate (``estimate_trace``, one probe) is rescaled
    to [0, 1] using the smallest and largest estimates seen so far in this
    epoch. ``0.001`` times that value is added to the loss before
    backpropagation. The scaled value is 0 until two distinct estimates exist.

    Returns:
        ``(train_acc, train_loss, batch_trace)``:
        - accuracy in percent;
        - sample-weighted mean task loss;
        - the sum of per-batch trace estimates divided by
          ``len(train_loader.dataset)``.
    """
    print("Training Epoch: {}".format(epoch))

    lo, hi = float('inf'), float('-inf')
    loss_sum, n_correct, trace_sum = 0., 0, 0.

    net.train()
    if scheduler:
        scheduler.step()

    desc = '[{}/{}]'.format(epoch, args.epoch)
    for x, y in tqdm(train_loader, desc=desc):
        if args.network == 'LR':
            x = x.view(-1, 28*28)
        if args.cuda:
            x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        logits = net(x)
        loss = taskloss(logits, y)
        loss_sum += loss.item() * x.size(0)

        trace = estimate_trace(loss, net, V=1)
        t = trace.item()
        trace_sum += t
        lo, hi = min(lo, t), max(hi, t)

        # an empty range (first batch, or all estimates equal) maps to zero
        if hi > lo:
            scaled = (trace - lo) / (hi - lo)
        else:
            scaled = torch.tensor(0.0)

        (loss + 0.001 * scaled).backward()
        optimizer.step()

        pred = logits.max(1, keepdim=True)[1]
        n_correct += pred.eq(y.view_as(pred)).sum().item()

    n = len(train_loader.dataset)
    return 100. * n_correct / n, loss_sum / n, trace_sum / n

def train_with_hessian(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch penalising layer-wise Hessian-vector-product norms.

    For every trainable parameter tensor θ_k, the diagonal Hessian block H_kk
    is probed with a fresh Gaussian vector v_k. The term
    ``args.hessLR * Σ_k ||H_kk v_k||`` is added to the task loss before
    backpropagation.

    Args:
        args: reads ``args.cuda`` and ``args.hessLR`` (penalty weight).
        epoch: not used here; kept for interface parity with the other loops.
        net: model; only parameters with ``requires_grad`` are probed.
        train_loader: ``(data, labels)`` batches exposing ``.dataset``.
        taskloss: callable returning a scalar (batch-mean) loss.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss, total_loss, l2_norm_total)``:
        - accuracy in percent;
        - sample-weighted mean task loss;
        - sample-weighted mean of the penalised loss;
        - the sum of per-batch penalty norms divided by
          ``len(train_loader.dataset)``.
    """
    train_corr = 0
    train_loss = 0.
    total_loss_sum = 0.
    l2_norm_sum = 0.
    params = [p for p in net.parameters() if p.requires_grad]

    net.train()
    if scheduler:
        scheduler.step()

    for data, labels in train_loader:
        if args.cuda:
            data, labels = data.cuda(), labels.cuda()
        optimizer.zero_grad()

        output = net(data)
        tloss = taskloss(output, labels)
        bsize = data.size(0)
        train_loss += tloss.item() * bsize

        predict = output.detach().max(1, keepdim=True)[1]
        train_corr += predict.eq(labels.view_as(predict)).sum().item()

        # First derivatives, kept differentiable for the Hessian-vector products.
        grads = torch.autograd.grad(tloss, params, create_graph=True)

        l2_norm_batch = 0.
        for param, g in zip(params, grads):
            flat_grads = g.view(-1)
            random_vector = torch.randn_like(flat_grads)
            # create_graph=True is what lets the penalty train the model: without
            # it autograd returns H·v as a constant with no gradient.
            hvp, = torch.autograd.grad(flat_grads, param, grad_outputs=random_vector,
                                       create_graph=True)
            # Same value as sqrt(sum(hvp**2)), but the norm's backward is 0
            # (not NaN) at a zero vector.
            l2_norm_batch = l2_norm_batch + hvp.norm()

        total_loss = tloss + args.hessLR * l2_norm_batch
        total_loss.backward()
        optimizer.step()

        # Accumulate plain floats so no autograd graph survives across batches.
        total_loss_sum += total_loss.item() * bsize
        l2_norm_sum += float(l2_norm_batch)

    n_samples = len(train_loader.dataset)
    train_loss /= n_samples
    train_acc = 100. * train_corr / n_samples
    return train_acc, train_loss, total_loss_sum / n_samples, l2_norm_sum / n_samples

def train_with_hessian_layer_track(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch with a layer-wise HVP-norm penalty and track per-layer norms.

    For every trainable parameter tensor θ_k, the diagonal Hessian block H_kk
    is probed with a fresh Gaussian vector v_k. The term
    ``args.hessLR * Σ_k ||H_kk v_k||`` is added to the task loss, and each
    ||H_kk v_k|| is recorded under the parameter's name.

    Args:
        args: reads ``args.cuda`` and ``args.hessLR`` (penalty weight).
        epoch: not used here; kept for interface parity with the other loops.
        net: model; only parameters with ``requires_grad`` are probed.
        train_loader: ``(data, labels)`` batches exposing ``.dataset``.
        taskloss: callable returning a scalar (batch-mean) loss.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss, total_loss, l2_norm_total, avg_layer_hessian_values)``:
        - accuracy in percent;
        - sample-weighted mean task loss;
        - sample-weighted mean of the penalised loss;
        - the sum of per-batch penalty norms divided by
          ``len(train_loader.dataset)``;
        - a dict mapping every parameter name to the mean over batches of
          its ||H_kk v_k|| (NaN for parameters never probed, e.g. frozen ones).
    """
    train_corr = 0
    train_loss = 0.
    total_loss_sum = 0.
    l2_norm_sum = 0.
    # per-parameter lists of per-batch norms, stored as plain floats
    layer_hessian_values = {name: [] for name, _ in net.named_parameters()}
    named_params = [(name, p) for name, p in net.named_parameters() if p.requires_grad]
    params = [p for _, p in named_params]

    net.train()
    if scheduler:
        scheduler.step()

    for data, labels in train_loader:
        if args.cuda:
            data, labels = data.cuda(), labels.cuda()
        optimizer.zero_grad()

        output = net(data)
        tloss = taskloss(output, labels)
        bsize = data.size(0)
        train_loss += tloss.item() * bsize

        predict = output.detach().max(1, keepdim=True)[1]
        train_corr += predict.eq(labels.view_as(predict)).sum().item()

        # First derivatives, kept differentiable for the Hessian-vector products.
        grads = torch.autograd.grad(tloss, params, create_graph=True)

        l2_norm_batch = 0.
        for (name, param), g in zip(named_params, grads):
            flat_grads = g.view(-1)
            random_vector = torch.randn_like(flat_grads)
            # create_graph=True is required for the penalty to produce gradients.
            hvp, = torch.autograd.grad(flat_grads, param, grad_outputs=random_vector,
                                       create_graph=True)
            # Same value as sqrt(sum(hvp**2)), but with a finite backward at zero.
            l2_norm_hvp = hvp.norm()
            # .item() so the tracked history does not keep each batch's graph alive
            layer_hessian_values[name].append(l2_norm_hvp.item())
            l2_norm_batch = l2_norm_batch + l2_norm_hvp

        total_loss = tloss + args.hessLR * l2_norm_batch
        total_loss.backward()
        optimizer.step()

        total_loss_sum += total_loss.item() * bsize
        l2_norm_sum += float(l2_norm_batch)

    n_samples = len(train_loader.dataset)
    train_loss /= n_samples
    train_acc = 100. * train_corr / n_samples

    avg_layer_hessian_values = {name: torch.tensor(values).mean().item()
                                for name, values in layer_hessian_values.items()}

    return (train_acc, train_loss, total_loss_sum / n_samples,
            l2_norm_sum / n_samples, avg_layer_hessian_values)



def train_with_hessian_trace(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch penalising a Hutchinson estimate of the Hessian trace.

    For every trainable parameter tensor θ_k, a Gaussian probe v_k gives
    v_kᵀ H_kk v_k, and the sum over k is an unbiased estimate of tr(H),
    since tr(H) is the sum of its diagonal-block traces. The term
    ``args.hessLR`` times this estimate is added to the task loss.

    Args:
        args: reads ``args.cuda`` and ``args.hessLR`` (penalty weight).
        epoch: not used here; kept for interface parity with the other loops.
        net: model; only parameters with ``requires_grad`` are probed.
        train_loader: ``(data, labels)`` batches exposing ``.dataset``.
        taskloss: callable returning a scalar (batch-mean) loss.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss, total_loss, trace_hessian)``:
        - accuracy in percent;
        - sample-weighted mean task loss;
        - sample-weighted mean of the penalised loss;
        - the sum of per-batch trace estimates divided by
          ``len(train_loader.dataset)``.
    """
    train_corr = 0
    train_loss = 0.
    total_loss_sum = 0.
    trace_hessian_sum = 0.
    params = [p for p in net.parameters() if p.requires_grad]

    net.train()
    if scheduler:
        scheduler.step()

    for data, labels in train_loader:
        if args.cuda:
            data, labels = data.cuda(), labels.cuda()
        optimizer.zero_grad()

        output = net(data)
        tloss = taskloss(output, labels)
        bsize = data.size(0)
        train_loss += tloss.item() * bsize

        predict = output.detach().max(1, keepdim=True)[1]
        train_corr += predict.eq(labels.view_as(predict)).sum().item()

        # First derivatives, kept differentiable for the Hessian-vector products.
        grads = torch.autograd.grad(tloss, params, create_graph=True)

        batch_trace_hessian = 0.
        for param, g in zip(params, grads):
            flat_grads = g.view(-1)
            random_vector = torch.randn_like(flat_grads)
            # create_graph=True is required for the trace penalty to produce
            # gradients; without it H·v is returned as a constant.
            hvp, = torch.autograd.grad(flat_grads, param, grad_outputs=random_vector,
                                       create_graph=True)
            batch_trace_hessian = batch_trace_hessian + torch.dot(random_vector, hvp.view(-1))

        total_loss = tloss + args.hessLR * batch_trace_hessian
        total_loss.backward()
        optimizer.step()

        # Plain floats: summing the graph-attached tensors kept every batch's
        # autograd history alive for the whole epoch.
        total_loss_sum += total_loss.item() * bsize
        trace_hessian_sum += float(batch_trace_hessian)

    n_samples = len(train_loader.dataset)
    train_loss /= n_samples
    train_acc = 100. * train_corr / n_samples
    return train_acc, train_loss, total_loss_sum / n_samples, trace_hessian_sum / n_samples


def train_with_adaHessian(args, epoch, net, train_loader, taskloss, optimizer, scheduler):
    """Train one epoch, keeping the gradient graph available to the optimizer.

    ``backward(create_graph=True)`` makes the stored gradients differentiable.
    This is presumably so a second-order optimizer (AdaHessian) can form
    Hessian-vector products inside ``step()`` — confirm with the caller.

    Args:
        args: reads ``args.cuda``.
        epoch: not used here; kept for interface parity with the other loops.
        net, train_loader, taskloss, optimizer: model, ``(data, labels)``
            batches exposing ``.dataset``, batch-mean loss callable, optimizer.
        scheduler: optional LR scheduler, stepped once before the epoch.

    Returns:
        ``(train_acc, train_loss)``: accuracy in percent and sample-weighted
        mean loss over ``train_loader.dataset``.
    """
    net.train()
    if scheduler:
        scheduler.step()

    hits, running = 0, 0.
    for x, y in train_loader:
        if args.cuda:
            x, y = x.cuda(), y.cuda()
        x, y = x.detach(), y.detach()
        optimizer.zero_grad()

        logits = net(x)
        loss = taskloss(logits, y)
        running += loss.detach().item() * x.size(0)

        guesses = logits.detach().max(1, keepdim=True)[1]
        hits += guesses.eq(y.detach().view_as(guesses)).sum().item()

        # keep the gradient graph so the optimizer can differentiate through it
        loss.backward(create_graph=True)
        optimizer.step()

    size = len(train_loader.dataset)
    return 100. * hits / size, running / size


def valid(args, epoch, net, valid_loader, taskloss, store_paths):
    """Evaluate ``net`` on ``valid_loader`` and print a one-line summary.

    Args:
        args: reads ``args.epoch`` and ``args.cuda``.
        epoch: current epoch index (progress bar and printed report).
        net: model; switched to eval mode.
        valid_loader: ``(data, labels)`` batches exposing ``.dataset``.
        taskloss: loss callable that must accept ``reduction='sum'``.
        store_paths: not used.

    Returns:
        ``(valid_acc, valid_loss)``: accuracy in percent and mean per-sample loss.
    """
    net.eval()

    n_correct, loss_total = 0, 0.
    for x, y in tqdm(valid_loader, desc='[{}/{}]'.format(epoch, args.epoch)):
        if args.cuda:
            x, y = x.cuda(), y.cuda()
        x, y = x.detach(), y.detach()
        with torch.no_grad():
            logits = net(x)
            guess = logits.max(1, keepdim=True)[1]
            n_correct += guess.eq(y.view_as(guess)).sum().item()
            # summed (not averaged) so dividing by the dataset size is exact
            loss_total += taskloss(logits, y, reduction='sum').item()

    size = len(valid_loader.dataset)
    valid_loss = loss_total / size
    valid_acc = 100. * n_correct / size

    print('  Epoch: {} [{}/{} (Acc: {:.4f}%)]\tAverage loss: {:.6f}'.format(
        epoch, n_correct, size, valid_acc, valid_loss))

    return valid_acc, valid_loss